Although the original SAM paper states that the model could not be fine-tuned, in the examples and in the approach we follow here the goal is to fine-tune the mask decoder so that it can segment any object we desire. Before the fine-tuning process, the details of the dataset and the data loader are covered in the 1.2 Datasets and 1.3 DataLoader sections. For fine-tuning, we use the HuggingFace Transformers framework (1.1 HF Transformers SAM's Fine Tune). We chose this framework because it makes the fine-tuning process straightforward and effective, and because it integrates well with the Flower framework, as mentioned in the 1.6 FedSAM section. In addition:

The HF Transformers library allows fine-tuning to start from the pretrained weights of the original SAM model rather than from the original implementation itself. Although many articles on SAM fine-tune the original model, in this project the SAM model from the HF Transformers library has been used. A second reason for choosing this model is that the library also provides a TensorFlow version of SAM, which allows integration with TensorFlow's Federated Learning library. The approach followed in the fine-tuning section is as follows:

First, we use the SamProcessor to prepare the samples in our 1.3 DataLoader. The processor performs the preprocessing steps required before images are fed into the SAM model, such as resizing images to a fixed input size, normalizing pixel values, and optionally converting images to a specific format. It also prepares visual prompts such as points or bounding boxes that guide the segmentation.

# Initialize the processor
from transformers import SamProcessor
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
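
To illustrate what the processor produces, the sketch below prepares a single image together with a bounding-box prompt. The file name and box coordinates are placeholders rather than values from this project; the 1.3 DataLoader is assumed to apply the same pattern to every sample.

# Hypothetical usage of the processor (placeholder image path and box coordinates)
from PIL import Image

image = Image.open("example.png").convert("RGB")
box = [100.0, 150.0, 400.0, 480.0]  # prompt as [x_min, y_min, x_max, y_max] in pixel coordinates

inputs = processor(images=image, input_boxes=[[box]], return_tensors="pt")
print(inputs["pixel_values"].shape)  # image resized and normalized to the model's fixed input size
print(inputs["input_boxes"].shape)   # box rescaled to match the resized image
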
  • As described in the sections [[1.0 On the Segment Anything Model]] and [[1.1 HF Transformers SAM's Fine Tune]], the model consists of three parts: the Vision Encoder, the Prompt Encoder, and the Mask Decoder. As seen in the code below, we first freeze the Vision Encoder and the Prompt Encoder. The reasons for this are:

    1. To preserve the rich features the model learned during pre-training.
    2. Training only the Mask Decoder lets the model adapt to a specific segmentation task without risking disruption of the encoders' general feature extraction capabilities.
    3. Efficiency and speed: the number of trainable parameters drops sharply, which significantly speeds up training.
    4. To prevent overfitting: when the dataset is small, updating all parameters can cause the model to learn noise and dataset-specific details instead of generalizing from the training set. Freezing the pre-trained components reduces this risk by preserving their generalized feature extraction capabilities.

from transformers import SamModel

# Load the SAM checkpoint from local files
model = SamModel.from_pretrained("./MainDir/UntitledFolder/checkpoint_sam_torch", local_files_only=True)

# Freeze the vision encoder and prompt encoder; only the mask decoder stays trainable
for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)
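
To confirm that the freezing behaves as intended, a quick sanity check (not part of the original code) can count how many parameters remain trainable:

# Sanity check: only the mask decoder's parameters should still require gradients
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,}")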

  • Loss function and optimizer: Dice-based losses are commonly used in segmentation models (e.g., U-Net) because they are often more effective for this task. We use MONAI's DiceCELoss, which combines the Dice loss with cross-entropy. For the optimizer, we use Adam.
from torch.optim import Adam
from monai.losses import DiceCELoss

# Optimize only the mask decoder's parameters
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

# Combined Dice + cross-entropy loss; sigmoid is applied to the raw mask logits
seg_loss = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
  • For training, we used two Tesla V100-PCIE-32GB GPUs, so the model was trained in parallel across both devices.
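
The section does not state which parallelization mechanism was used; the sketch below assumes plain PyTorch DataParallel as one simple way to split each batch across the two GPUs.

# Illustrative multi-GPU setup (assumes torch.nn.DataParallel; other strategies such as DDP would also work)
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)  # replicates the model and splits each batch across the GPUs
model.to(device)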

  • Next, we move on to training. From each batch we receive pixel_values, input_boxes, and ground_truth_masks; the first two are used in the model's forward pass, and the masks are used for the loss calculation.

  • The model performs a forward pass with the received pixel_values and input_boxes.

outputs = model(pixel_values=pixel_values, input_boxes=input_boxes, multimask_output=False)

  • In the loss calculation part, the redundant dimension of the predicted masks is squeezed out, and the predictions are compared against the ground-truth segmentation masks with the loss defined above.
predicted_masks = outputs.pred_masks.squeeze(1)                    # remove the extra per-prompt dimension
loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))  # add a channel dimension to the targets to match
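
Putting these pieces together, a minimal training-loop sketch could look as follows. The loop structure (zero_grad / backward / step) is standard PyTorch; the names train_dataloader, num_epochs, and the batch keys are assumptions rather than values taken from this project.

# Minimal training-loop sketch (train_dataloader, num_epochs, and batch keys are placeholders)
device = next(model.parameters()).device
num_epochs = 10

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in train_dataloader:
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)

        # Forward pass with image and box prompts
        outputs = model(pixel_values=pixel_values, input_boxes=input_boxes, multimask_output=False)

        # Dice + cross-entropy loss against the ground-truth masks
        predicted_masks = outputs.pred_masks.squeeze(1)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # Backpropagate and update only the mask decoder's parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f"Epoch {epoch}: mean loss = {sum(epoch_losses) / len(epoch_losses):.4f}")
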
  • In the validation part, gradient calculation is disabled and the model is evaluated on the test dataset without any further weight updates. Then, an example image is segmented and displayed after each epoch.
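
A sketch of that validation step is shown below; val_dataloader and its batch keys are assumptions mirroring the training loop above, and the per-epoch display of an example image is omitted.

# Illustrative validation sketch (val_dataloader and batch keys are placeholders)
import torch

device = next(model.parameters()).device
model.eval()
val_losses = []
with torch.no_grad():  # disable gradient tracking so no weights can be updated
    for batch in val_dataloader:
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)

        outputs = model(pixel_values=pixel_values, input_boxes=input_boxes, multimask_output=False)
        predicted_masks = outputs.pred_masks.squeeze(1)
        val_losses.append(seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1)).item())

print(f"Validation loss: {sum(val_losses) / len(val_losses):.4f}")
model.train()  # switch back to training mode for the next epoch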

Example Validation Output
